{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e66c7d67581b"
      },
      "source": [
        "# Fine-Tuning CodeGemma on the SQL Spider Dataset\n",
        "**Author**: Carlo Fisicaro  \n",
        "**GitHub**: [github.com/carlofisicaro](https://github.com/carlofisicaro)  \n",
        "**X**: [@carlo_fisicaro](https://twitter.com/carlo_fisicaro)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dfsDR_omdNea"
      },
      "source": [
        "# CodeGemma text-to-SQL (Hugging Face)\n",
        "This notebook demonstrates how to load, fine-tune, and publish a CodeGemma model for SQL generation using Hugging Face.\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/CodeGemma/[CodeGemma_1]Finetune_with_SQL.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FaqZItBdeokU"
      },
      "source": [
        "## Setup\n",
        "\n",
        "### Select the Colab runtime\n",
        "To complete this tutorial, you'll need a Colab runtime with sufficient resources to run the CodeGemma model. In this case, you can use a T4 GPU:\n",
        "\n",
        "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n",
        "2. Select **Change runtime type**.\n",
        "3. Under **Hardware accelerator**, select **T4 GPU**.\n",
        "\n",
        "### CodeGemma setup\n",
        "\n",
        "**Before we dive into the tutorial, let's get you set up with CodeGemma:**\n",
        "\n",
        "1. **Hugging Face Account:**  If you don't already have one, you can create a free Hugging Face account by clicking [here](https://huggingface.co/join).\n",
        "2. **CodeGemma Model Access:** Head over to the [CodeGemma model page](https://huggingface.co/google/codegemma-7b-it) and accept the usage conditions.\n",
        "3. **Colab with Gemma Power:**  For this tutorial, you'll need a Colab runtime with enough resources to handle the CodeGemma 7B model. Choose an appropriate runtime when starting your Colab session.\n",
        "4. **Hugging Face Token:**  Generate a Hugging Face access token (preferably with `write` permission) by clicking [here](https://huggingface.co/settings/tokens). You'll need this token later in the tutorial.\n",
        "\n",
        "**Once you've completed these steps, you're ready to move on to the next section where we'll set up environment variables in your Colab environment.**\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CY2kGtsyYpHF"
      },
      "source": [
        "### Configure your HF token\n",
        "\n",
        "Add your Hugging Face token to the Colab Secrets manager to securely store it.\n",
        "\n",
        "1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n",
        "2. Create a new secret with the name `HF_TOKEN`.\n",
        "3. Copy/paste your token key into the Value input box of `HF_TOKEN`.\n",
        "4. Toggle the button on the left to allow notebook access to the secret.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A9sUQ4WrP-Yr"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "# vars as appropriate for your system.\n",
        "os.environ[\"HF_TOKEN\"] = userdata.get(\"HF_TOKEN\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iwjo5_Uucxkw"
      },
      "source": [
        "### Install dependencies\n",
        "Run the cell below to install all the required dependencies."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r_nXPEsF7UWQ"
      },
      "outputs": [],
      "source": [
        "!pip install --upgrade -q transformers huggingface_hub peft \\\n",
        "  accelerate bitsandbytes datasets trl evaluate sqlparse"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2_bahJBmwvSp"
      },
      "source": [
        "### Log into Hugging Face Hub\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GIFFCHi-wvSp"
      },
      "outputs": [],
      "source": [
        "from huggingface_hub import login\n",
        "\n",
        "login(os.environ[\"HF_TOKEN\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gFLddpGeaKh5"
      },
      "source": [
        "All set and ready to explore the possibilities with CodeGemma!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yXFZFUJHgTcU"
      },
      "source": [
        "## Instantiate the CodeGemma 7B model\n",
        "\n",
        "CodeGemma is a collection of powerful, lightweight models that can perform a variety of coding tasks like fill-in-the-middle code completion, code generation, natural language understanding, mathematical reasoning, and instruction following.\n",
        "Here we're loading the 7B instruction-tuned variant for natural language-to-code chat and instruction following.\n",
        "\n",
        "\n",
        "Let's get started by loading the model from Hugging Face Hub."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jgl8ZjHpwvSq"
      },
      "source": [
        "### Loading the model from HF Hub"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w_z4600bwvSq"
      },
      "outputs": [],
      "source": [
        "model_id = \"google/codegemma-7b-it\"\n",
        "device = \"cuda\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "74tpQWWWwvSq"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
            "  from .autonotebook import tqdm as notebook_tqdm\n"
          ]
        }
      ],
      "source": [
        "# Let's load the tokenizer first\n",
        "from transformers import AutoTokenizer\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_id)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UD-eXTxxwvSq"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00,  1.38s/it]\n"
          ]
        }
      ],
      "source": [
        "import torch\n",
        "from transformers import (\n",
        "    AutoModelForCausalLM,\n",
        "    BitsAndBytesConfig,\n",
        ")\n",
        "\n",
        "# Quantize the weights to 4-bit NF4 to reduce the model's memory footprint\n",
        "bnb_config = BitsAndBytesConfig(\n",
        "    load_in_4bit=True,\n",
        "    bnb_4bit_quant_type=\"nf4\",\n",
        "    bnb_4bit_compute_dtype=torch.bfloat16,\n",
        ")\n",
        "\n",
        "# Let's load the final model\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_id,\n",
        "    quantization_config=bnb_config,\n",
        "    device_map={\"\": 0},\n",
        ")\n"
      ]
    },
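    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why 4-bit loading matters on a 16 GB T4, the weight-memory arithmetic can be sketched as follows. This is a back-of-the-envelope estimate rather than a measurement: `approx_weight_memory_gb` is a hypothetical helper, the 7-billion parameter count is a round-number assumption, and quantization constants, activations, and the KV cache are ignored.\n",
        "\n",
        "```python\n",
        "def approx_weight_memory_gb(num_params: float, bits_per_param: int) -> float:\n",
        "    \"\"\"Rough size of the model weights alone, in gigabytes (1 GB = 1e9 bytes).\"\"\"\n",
        "    return num_params * bits_per_param / 8 / 1e9\n",
        "\n",
        "\n",
        "# Assumption: about 7 billion parameters, used only for illustration.\n",
        "NUM_PARAMS = 7e9\n",
        "\n",
        "for label, bits in [(\"16-bit\", 16), (\"4-bit (NF4)\", 4)]:\n",
        "    print(f\"{label}: ~{approx_weight_memory_gb(NUM_PARAMS, bits):.1f} GB\")\n",
        "```\n",
        "\n",
        "Under these assumptions the weights shrink to roughly a quarter of their 16-bit size, which is what makes room for the LoRA adapters and optimizer state used later."
      ]
    },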
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1ba332aee469"
      },
      "source": [
        "Before fine-tuning, let's see how the base model answers a plain natural-language question when no preamble tells it that we want a SQL query."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Lyw7fwOGwvSq"
      },
      "source": [
        "### Trying it out"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b8be32c8f87b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Tell me the title of the product list page with the highest conversion rate to detail pages in February 2021.\n",
            "\n",
            "The product list page with the highest conversion rate to detail pages in February 2021 is the **Women's Clothing** page. This page had a conversion rate of **2.5%**, which means that for every 100 visitors to the page, 2.5 of them clicked through to a product detail page.\n",
            "\n",
            "This is a significant conversion rate, and it suggests that the Women's Clothing page is doing a good job of converting visitors into customers. The page features a wide variety of products, from dresses and skirts to jeans and sweaters, and it also provides a variety of helpful features, such as product filters and a search bar. These features make it easy for visitors to find the products they are looking for, and they are also likely to contribute to the high conversion rate.\n"
          ]
        }
      ],
      "source": [
        "prompt = (\n",
        "    \"Tell me the title of the product list page with the highest conversion \"\n",
        "    \"rate to detail pages in February 2021.\"\n",
        ")\n",
        "\n",
        "inputs = tokenizer.encode(\n",
        "    prompt,\n",
        "    return_tensors=\"pt\"\n",
        ").to(device)\n",
        "\n",
        "outputs = model.generate(\n",
        "    inputs,\n",
        "    max_new_tokens=500\n",
        ")\n",
        "\n",
        "text = tokenizer.decode(\n",
        "    outputs[0],\n",
        "    skip_special_tokens=True\n",
        ")\n",
        "\n",
        "print(text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1e7f17d5ff43"
      },
      "source": [
        "Let's ask CodeGemma an ambiguous question."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nrVBVTtlwvSq"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "What is the place with the zip code in which the average mean sea level pressure is the lowest? Generate the SQL query with python.\n",
            "\n",
            "```python\n",
            "import pandas as pd\n",
            "import sqlalchemy as sa\n",
            "\n",
            "# Create a connection to the database\n",
            "engine = sa.create_engine('postgresql://postgres:password@localhost:5432/postgres')\n",
            "\n",
            "# Create a query to get the average mean sea level pressure for each zip code\n",
            "query = \"\"\"\n",
            "SELECT zip_code, AVG(mean_sea_level_pressure) AS avg_pressure\n",
            "FROM weather_data\n",
            "GROUP BY zip_code\n",
            "ORDER BY avg_pressure ASC\n",
            "LIMIT 1;\n",
            "\"\"\"\n",
            "\n",
            "# Execute the query and store the results in a DataFrame\n",
            "df = pd.read_sql_query(query, engine)\n",
            "\n",
            "# Print the zip code with the lowest average mean sea level pressure\n",
            "print(df['zip_code'].iloc[0])\n",
            "```\n"
          ]
        }
      ],
      "source": [
        "prompt = (\n",
        "    \"What is the place with the zip code in which the average mean sea level \"\n",
        "    \"pressure is the lowest? Generate the SQL query with python.\"\n",
        ")\n",
        "\n",
        "inputs = tokenizer.encode(\n",
        "    prompt,\n",
        "    return_tensors=\"pt\"\n",
        ").to(device)\n",
        "\n",
        "outputs = model.generate(\n",
        "    inputs,\n",
        "    max_new_tokens=200\n",
        ")\n",
        "\n",
        "text = tokenizer.decode(\n",
        "    outputs[0],\n",
        "    skip_special_tokens=True\n",
        ")\n",
        "\n",
        "print(text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "99ec55132cf1"
      },
      "source": [
        "The question is ambiguous because it's not clear whether we're asking for:\n",
        "* a Python script that produces a SQL query\n",
        "* two separate snippets, one in Python and one in SQL.\n",
        "\n",
        "CodeGemma picked the first option. Keep this in mind!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QlFBTx33wvSq"
      },
      "source": [
        "## Fine-tuning the model with LoRA\n",
        "\n",
        "This section fine-tunes the model with LoRA (Low-Rank Adaptation) on the Spider dataset so that it produces high-quality SQL queries from natural-language questions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8_iH8JINwvSr"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Example item: {'db_id': 'department_management', 'query': 'SELECT count(*) FROM head WHERE age  >  56', 'question': 'How many heads of the departments are older than 56 ?', 'query_toks': ['SELECT', 'count', '(', '*', ')', 'FROM', 'head', 'WHERE', 'age', '>', '56'], 'query_toks_no_value': ['select', 'count', '(', '*', ')', 'from', 'head', 'where', 'age', '>', 'value'], 'question_toks': ['How', 'many', 'heads', 'of', 'the', 'departments', 'are', 'older', 'than', '56', '?']}\n"
          ]
        }
      ],
      "source": [
        "# Load the Spider text-to-SQL dataset\n",
        "from datasets import load_dataset\n",
        "\n",
        "data = load_dataset(\"xlangai/spider\")\n",
        "print(\"Example item:\", data[\"train\"][0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "712100e0d160"
      },
      "source": [
        "We need two helper functions: one that formats each example, and one that tokenizes the resulting 'questions' and 'queries' columns for training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a88ac2e1982a"
      },
      "outputs": [],
      "source": [
        "import sqlparse\n",
        "\n",
        "\n",
        "# Formatting function to preprocess the data\n",
        "def formatting_func(samples):\n",
        "    questions_with_preamble = [\n",
        "        f\"{question} SQL:\" for question in samples[\"question\"]\n",
        "    ]\n",
        "\n",
        "    sql_queries = []\n",
        "    for query in samples[\"query\"]:\n",
        "        sql_query = sqlparse.format(\n",
        "            query, reindent=True, keyword_case='upper'\n",
        "        )\n",
        "        sql_queries.append(sql_query)\n",
        "\n",
        "    formatted_queries = [\n",
        "        f\"```sql\\n{query}\\n```\" for query in sql_queries\n",
        "    ]\n",
        "\n",
        "    return {\n",
        "        \"questions\": questions_with_preamble,\n",
        "        \"queries\": formatted_queries\n",
        "    }\n",
        "\n",
        "\n",
        "# Tokenization function\n",
        "def tokenize_function(samples):\n",
        "    max_length = 1024  # Set a reasonable max_length based on your data\n",
        "\n",
        "    inputs = tokenizer(\n",
        "        samples[\"questions\"],\n",
        "        truncation=True,\n",
        "        padding=\"max_length\",\n",
        "        max_length=max_length,\n",
        "        return_tensors=\"pt\"\n",
        "    )\n",
        "\n",
        "    outputs = tokenizer(\n",
        "        samples[\"queries\"],\n",
        "        truncation=True,\n",
        "        padding=\"max_length\",\n",
        "        max_length=max_length,\n",
        "        return_tensors=\"pt\"\n",
        "    )\n",
        "\n",
        "    return {\n",
        "        \"input_ids\": inputs[\"input_ids\"],\n",
        "        \"labels\": outputs[\"input_ids\"]\n",
        "    }"
      ]
    },
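    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cell above tokenizes questions and queries separately. A common alternative for causal language models is to join the prompt and the answer into one sequence and mask the prompt positions in `labels` with `-100`, the index that PyTorch's cross-entropy loss ignores by default. The sketch below shows only that masking logic. It uses a hypothetical whitespace `toy_tokenize` function in place of the real tokenizer, and it is not what this notebook runs.\n",
        "\n",
        "```python\n",
        "IGNORE_INDEX = -100  # Default ignore_index of torch.nn.CrossEntropyLoss\n",
        "\n",
        "\n",
        "def toy_tokenize(text, vocab):\n",
        "    \"\"\"Hypothetical stand-in for a real tokenizer: one id per whitespace token.\"\"\"\n",
        "    return [vocab.setdefault(token, len(vocab)) for token in text.split()]\n",
        "\n",
        "\n",
        "def build_example(question, query, vocab):\n",
        "    \"\"\"Join prompt and answer, masking the prompt so only the answer is scored.\"\"\"\n",
        "    prompt_ids = toy_tokenize(f\"{question} SQL:\", vocab)\n",
        "    answer_ids = toy_tokenize(query, vocab)\n",
        "    return {\n",
        "        \"input_ids\": prompt_ids + answer_ids,\n",
        "        \"labels\": [IGNORE_INDEX] * len(prompt_ids) + answer_ids,\n",
        "    }\n",
        "\n",
        "\n",
        "vocab = {}\n",
        "example = build_example(\"How many heads ?\", \"SELECT count(*) FROM head\", vocab)\n",
        "print(example[\"input_ids\"])\n",
        "print(example[\"labels\"])\n",
        "```\n",
        "\n",
        "With this layout `input_ids` and `labels` always have the same length, and the loss is computed only over the SQL tokens."
      ]
    },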
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "65Cmr3fBw9bN"
      },
      "outputs": [],
      "source": [
        "# Apply the formatting function to the dataset\n",
        "data = data.map(formatting_func, batched=True)\n",
        "\n",
        "# Apply the tokenization function to the formatted data\n",
        "data = data.map(tokenize_function, batched=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C8RKs_oZwvSr"
      },
      "outputs": [],
      "source": [
        "from peft import LoraConfig\n",
        "\n",
        "# Define tuning parameters\n",
        "lora_config = LoraConfig(\n",
        "    r=8,\n",
        "    task_type=\"CAUSAL_LM\",\n",
        "    target_modules=[\n",
        "        \"q_proj\",\n",
        "        \"o_proj\",\n",
        "        \"k_proj\",\n",
        "        \"v_proj\",\n",
        "        \"gate_proj\",\n",
        "        \"up_proj\",\n",
        "        \"down_proj\",\n",
        "    ],\n",
        ")"
      ]
    },
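    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of why LoRA is cheap, each adapted weight matrix of shape `(d_out, d_in)` gains two low-rank factors holding `r * (d_in + d_out)` trainable values in total. The sketch below works through that arithmetic for `r=8`. The 4096 x 4096 layer shape is a made-up example rather than CodeGemma's actual dimensions, and `lora_params` is a hypothetical helper.\n",
        "\n",
        "```python\n",
        "def lora_params(d_in: int, d_out: int, r: int) -> int:\n",
        "    \"\"\"Trainable values LoRA adds to one linear layer: A is (r, d_in), B is (d_out, r).\"\"\"\n",
        "    return r * d_in + d_out * r\n",
        "\n",
        "\n",
        "# Assumption: a square 4096 x 4096 projection, chosen only for illustration.\n",
        "d_in, d_out, r = 4096, 4096, 8\n",
        "\n",
        "full = d_in * d_out\n",
        "added = lora_params(d_in, d_out, r)\n",
        "print(f\"full matrix: {full:,} values\")\n",
        "print(f\"LoRA factors: {added:,} values ({added / full:.2%} of the full matrix)\")\n",
        "```\n",
        "\n",
        "Raising `r` increases adapter capacity and trainable-parameter count linearly, so it is a natural knob to tune if the adapter underfits."
      ]
    },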
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9b75c6384048"
      },
      "outputs": [],
      "source": [
        "train_data = data[\"train\"].shuffle(seed=1234).select(range(100))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oysHd0jXwvSr"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:292: UserWarning: You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to 1024\n",
            "  warnings.warn(\n",
            "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:493: UserWarning: You passed a dataset that is already processed (contains an `input_ids` field) together with a valid formatting function. Therefore `formatting_func` will be ignored.\n",
            "  warnings.warn(\n",
            "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:396: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.\n",
            "  warnings.warn(\n",
            "max_steps is given, it will override any value given in num_train_epochs\n"
          ]
        }
      ],
      "source": [
        "import transformers\n",
        "from trl import SFTTrainer\n",
        "\n",
        "# Create the trainer object that manages the fine-tuning loop\n",
        "trainer = SFTTrainer(\n",
        "    model=model,\n",
        "    train_dataset=train_data,\n",
        "    args=transformers.TrainingArguments(\n",
        "        per_device_train_batch_size=1,\n",
        "        gradient_accumulation_steps=4,\n",
        "        warmup_steps=2,\n",
        "        max_steps=50,\n",
        "        learning_rate=2e-4,\n",
        "        fp16=True,\n",
        "        output_dir=\"outputs\",\n",
        "        logging_dir=\"./logs\",\n",
        "        logging_strategy=\"steps\",\n",
        "        logging_steps=1,\n",
        "        optim=\"paged_adamw_8bit\",\n",
        "    ),\n",
        "    peft_config=lora_config,\n",
        "    formatting_func=formatting_func,\n",
        ")"
      ]
    },
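    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The training arguments above interact in a way that is easy to check by hand. With a per-device batch size of 1 and 4 gradient-accumulation steps, each optimizer step consumes 4 examples, so 50 steps over the 100-example subset amount to 2 passes over the data. The sketch below repeats that arithmetic, assuming a single GPU; `epochs_covered` is a hypothetical helper.\n",
        "\n",
        "```python\n",
        "def epochs_covered(num_examples, batch_size, grad_accum_steps, max_steps, num_devices=1):\n",
        "    \"\"\"Number of passes over the data implied by a fixed number of optimizer steps.\"\"\"\n",
        "    examples_per_step = batch_size * grad_accum_steps * num_devices\n",
        "    return max_steps * examples_per_step / num_examples\n",
        "\n",
        "\n",
        "# Values taken from the cells above; num_devices=1 is an assumption.\n",
        "epochs = epochs_covered(num_examples=100, batch_size=1, grad_accum_steps=4, max_steps=50)\n",
        "print(epochs)\n",
        "```\n",
        "\n",
        "Changing `max_steps` or the size of the training subset changes the number of passes accordingly, which is worth keeping in mind when scaling beyond this small demo."
      ]
    },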
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yCeOevVHsJGX"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='50' max='50' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [50/50 01:50, Epoch 2/2]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>152.741500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>109.213100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>164.229600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>120.124200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>139.504700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>6</td>\n",
              "      <td>89.775900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>7</td>\n",
              "      <td>110.598600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>8</td>\n",
              "      <td>118.878200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>9</td>\n",
              "      <td>81.384500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>10</td>\n",
              "      <td>114.521900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>11</td>\n",
              "      <td>92.502300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>12</td>\n",
              "      <td>88.829600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>13</td>\n",
              "      <td>104.187100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>14</td>\n",
              "      <td>111.698900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>15</td>\n",
              "      <td>67.454400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>16</td>\n",
              "      <td>110.968400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>17</td>\n",
              "      <td>72.439300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>18</td>\n",
              "      <td>60.063300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>19</td>\n",
              "      <td>65.954000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>20</td>\n",
              "      <td>87.239500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>21</td>\n",
              "      <td>64.619300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>22</td>\n",
              "      <td>58.157300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>23</td>\n",
              "      <td>64.746200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>24</td>\n",
              "      <td>61.815700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>25</td>\n",
              "      <td>45.756500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>26</td>\n",
              "      <td>53.231300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>27</td>\n",
              "      <td>39.935800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>28</td>\n",
              "      <td>32.591100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>29</td>\n",
              "      <td>33.389000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>30</td>\n",
              "      <td>35.493500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>31</td>\n",
              "      <td>33.897600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>32</td>\n",
              "      <td>27.153400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>33</td>\n",
              "      <td>29.196500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>34</td>\n",
              "      <td>24.632600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>35</td>\n",
              "      <td>20.031000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>36</td>\n",
              "      <td>17.797800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>37</td>\n",
              "      <td>19.721600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>38</td>\n",
              "      <td>13.950200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>39</td>\n",
              "      <td>10.526600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>40</td>\n",
              "      <td>10.215900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>41</td>\n",
              "      <td>9.000700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>42</td>\n",
              "      <td>7.255400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>43</td>\n",
              "      <td>6.684400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>44</td>\n",
              "      <td>7.408300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>45</td>\n",
              "      <td>7.284500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>46</td>\n",
              "      <td>7.613000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>47</td>\n",
              "      <td>7.755100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>48</td>\n",
              "      <td>7.623400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>49</td>\n",
              "      <td>5.783600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>50</td>\n",
              "      <td>7.106700</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "TrainOutput(global_step=50, training_loss=56.65366875648498, metrics={'train_runtime': 112.6611, 'train_samples_per_second': 1.775, 'train_steps_per_second': 0.444, 'total_flos': 9555457081344000.0, 'train_loss': 56.65366875648498, 'epoch': 2.0})"
            ]
          },
          "execution_count": 13,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Let's run the fine-tuning\n",
        "trainer.train()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5b74f2a7bea6"
      },
      "source": [
        "Let's ask our SQL fine-tuned CodeGemma the same ambiguous question."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y1hPDZgZwvSr"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:567: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.2` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "What is the place with the zip code in which the average mean sea level pressure is the lowest? Generate the SQL query with python code to find the answer.\n",
            "\n",
            "```sql\n",
            "SELECT zip_code, avg(mean_sea_level_pressure) AS average_pressure\n",
            "FROM weather_data\n",
            "GROUP BY zip_code\n",
            "ORDER BY average_pressure ASC\n",
            "LIMIT 1;\n",
            "```\n",
            "\n",
            "```python\n",
            "import pandas as pd\n",
            "\n",
            "# Read the weather data from a CSV file\n",
            "weather_data = pd.read_csv('weather_data.csv')\n",
            "\n",
            "# Group the data by zip code and calculate the average mean sea level pressure for each zip code\n",
            "average_pressure_by_zip_code = weather_data.groupby('zip_code')['mean_sea_level_pressure'].mean()\n",
            "\n",
            "# Find the zip code with the lowest average mean sea level pressure\n",
            "zip_code_with_lowest_average_pressure = average_pressure_by_zip_code.idxmin()\n",
            "\n",
            "# Print the zip code with the lowest average mean sea level pressure\n",
            "print(zip_code_with_lowest_average_pressure)\n",
            "```\n"
          ]
        }
      ],
      "source": [
        "# Testing the models after fine-tuning\n",
        "text = (\n",
        "    \"What is the place with the zip code in which the average mean sea level \"\n",
        "    \"pressure is the lowest? Generate the SQL query with python\"\n",
        ")\n",
        "\n",
        "inputs = tokenizer(\n",
        "    text,\n",
        "    return_tensors=\"pt\"\n",
        ").to(device)\n",
        "\n",
        "outputs = model.generate(\n",
        "    **inputs,\n",
        "    max_length=300,\n",
        "    temperature=0.2,  # Low temperature for deterministic output\n",
        "    top_k=50,  # Limits the randomness\n",
        ")\n",
        "\n",
        "print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b526d4d3f1a1"
      },
      "source": [
        "This time the model picked the second option providing two separate scripts producing respectively, python and SQL code!\n",
        "\n",
        "The model knows we 'prefer' to get a SQL query now but it didn't forget the other programming languages it's been trained on."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p3jL-Z8CtqgP"
      },
      "source": [
        "## Push the model to your Hugging Face Hub\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aM84Ti3r02Tz"
      },
      "source": [
        "Hugging Face allow to you easily store trained models in their hub."
      ]
    },
    {
      "cell_type": "raw",
      "metadata": {
        "id": "HIDWBva0_SX4"
      },
      "source": [
        "# Note: The token needs to have \"write\" permission\n",
        "#       You can check it here:\n",
        "#       https://huggingface.co/settings/tokens\n",
        "model.push_to_hub(\"my-codegemma-7-finetuned-model\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5e-PkWR6wvSr"
      },
      "source": [
        "## Serve you model using Text Generation Inference (TGI)\n",
        "\n",
        "Text Generation Inference is a toolkit that simplifies deploying and using large language models (LLMs) like Gemma. It optimizes models for text generation tasks, enabling them to run faster and produce results quicker. TGI achieves this through techniques like tensor parallelism, which distributes the workload across multiple graphics cards (GPUs) for faster processing, and optimized code specifically designed for text generation. Additionally, TGI offers features that make it suitable for production environments, such as distributed tracing for monitoring model performance, Prometheus metrics for detailed data collection, and security measures like watermarking to protect model outputs. You can read more about TGI by referring to [the official documentation](https://huggingface.co/docs/text-generation-inference/en/index)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bmx6iT6xp0RI"
      },
      "source": [
        "To deploy your model with TGI you can either:\n",
        "\n",
        "1. **Deploy it locally (requires Docker):** Uncomment the code cells below to run the model on your local machine. This approach requires Docker to be installed and GPU attached.\n",
        "\n",
        "2. **Deploy it on Google Cloud Platform using GKE:** Follow this guide [Serve Gemma open models using GPUs on GKE with Hugging Face TGI](https://cloud.google.com/kubernetes-engine/docs/tutorials/serve-gemma-gpu-tgi) to deploy your model on Google Cloud's CKE service. This option leverages GPUs for high-performance inference.\n",
        "\n",
        "Both deployment methods will provide you with an HTTP endpoint for sending requests and receiving text generation responses from your model."
      ]
    },
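    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Once either deployment is running, you can query it over HTTP. The snippet below is a minimal sketch rather than a definitive client: it assumes the server is reachable at `http://localhost:8080` (the port mapped by the local Docker command) and that it exposes TGI's `/generate` route, which takes a JSON body with `inputs` and `parameters`. The helper names `build_request` and `generate` are illustrative.\n",
        "\n",
        "```python\n",
        "import json\n",
        "import urllib.request\n",
        "\n",
        "# Assumption: TGI was started locally with `-p 8080:80`.\n",
        "TGI_URL = 'http://localhost:8080/generate'\n",
        "\n",
        "\n",
        "def build_request(prompt, max_new_tokens=128):\n",
        "    # Build a POST request carrying the prompt and generation parameters.\n",
        "    payload = {\n",
        "        'inputs': prompt,\n",
        "        'parameters': {'max_new_tokens': max_new_tokens},\n",
        "    }\n",
        "    return urllib.request.Request(\n",
        "        TGI_URL,\n",
        "        data=json.dumps(payload).encode('utf-8'),\n",
        "        headers={'Content-Type': 'application/json'},\n",
        "        method='POST',\n",
        "    )\n",
        "\n",
        "\n",
        "def generate(prompt, max_new_tokens=128):\n",
        "    # Send the prompt to the server and return the generated text.\n",
        "    with urllib.request.urlopen(build_request(prompt, max_new_tokens)) as resp:\n",
        "        return json.loads(resp.read())['generated_text']\n",
        "\n",
        "\n",
        "# Example call (requires a running TGI server):\n",
        "# print(generate('Generate the SQL query that counts the rows in a table.'))\n",
        "```\n",
        "\n",
        "Adjust `TGI_URL` if you deployed on GKE or mapped a different port."
      ]
    },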
    {
      "cell_type": "raw",
      "metadata": {
        "id": "0wEjhtJawvSr"
      },
      "source": [
        "!model=\"google/codegemma-7b-it\" # ID of the model in Hugging Face hub\n",
        "# (you can use your own fine-tuned model from\n",
        "# the previous step)\n",
        "!volume=$PWD/data               # Shared directory with the Docker container\n",
        "# to avoid downloading weights every run\n",
        "\n",
        "# !docker run --gpus all --shm-size 1g -p 8080:80 \\\n",
        "#     -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.3 \\\n",
        "#     --model-id $model"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "[CodeGemma_1]Finetune_with_SQL.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
